#
# H2O Grid Support
#
# Provides a set of functions to launch a grid search and get
# its results.

#-------------------------------------
# Grid-related functions start here :)
#-------------------------------------

#'
#' Launch grid search with given algorithm and parameters.
#'
#' @param algorithm  Name of algorithm to use in grid search (gbm, randomForest, kmeans, glm, deeplearning, naivebayes, pca).
#' @param grid_id  (Optional) ID for resulting grid search. If it is not specified then it is autogenerated.
#' @param x (Optional) A vector containing the names or indices of the predictor variables to use in building the model.
#'        If x is missing, then all columns except y are used.
#' @param y The name or column index of the response variable in the data. The response must be either a numeric or a
#'        categorical/factor variable. If the response is numeric, then a regression model will be trained, otherwise it will train a classification model.
#' @param training_frame Id of the training data frame.
#' @param ...  arguments describing parameters to use with algorithm (i.e., x, y, training_frame).
#'        Look at the specific algorithm - h2o.gbm, h2o.glm, h2o.kmeans, h2o.deeplearning - for available parameters.
#' @param hyper_params  List of lists of hyper parameters (i.e., \code{list(ntrees=c(1,2), max_depth=c(5,7))}).
#' @param is_supervised [Deprecated] It is not possible to override default behaviour. (Optional) If specified then override the default heuristic which decides if the given algorithm
#'        name and parameters specify a supervised or unsupervised algorithm.
#' @param do_hyper_params_check  Perform client check for specified hyper parameters. It can be time expensive for
#'        large hyper space.
#' @param search_criteria  (Optional)  List of control parameters for smarter hyperparameter search.  The list can 
#'        include values for: strategy, max_models, max_runtime_secs, stopping_metric, stopping_tolerance, stopping_rounds and
#'        seed.  The default strategy 'Cartesian' covers the entire space of hyperparameter combinations.  If you want to use
#'        cartesian grid search, you can leave the search_criteria argument unspecified. Specify the "RandomDiscrete" strategy
#'        to get random search of all the combinations of your hyperparameters with three ways of specifying when to stop the
#'        search: max number of models, max time, and metric-based early stopping (e.g., stop if MSE has not improved by 0.0001
#'        over the 5 best models). Examples below:
#'        \code{list(strategy = "RandomDiscrete", max_runtime_secs = 600, max_models = 100, stopping_metric = "AUTO",
#'        stopping_tolerance = 0.00001, stopping_rounds = 5, seed = 123456)} or \code{list(strategy = "RandomDiscrete", 
#'        max_models = 42, max_runtime_secs = 28800)} or \code{list(strategy = "RandomDiscrete", stopping_metric = "AUTO", 
#'        stopping_tolerance = 0.001, stopping_rounds = 10)} or \code{list(strategy = "RandomDiscrete", stopping_metric = 
#'        "misclassification", stopping_tolerance = 0.00001, stopping_rounds = 5)}.
#' @param export_checkpoints_dir Directory to automatically export grid and its models to.
#' @param recovery_dir When specified the grid and all necessary data (frames, models) will be saved to this
#'        directory (use HDFS or other distributed file-system). Should the cluster crash during training, the grid
#'        can be reloaded from this directory via \code{h2o.loadGrid} and training can be resumed.
#' @param parallelism Level of Parallelism during grid model building. 1 = sequential building (default).
#'        Use the value of 0 for adaptive parallelism - decided by H2O. Any number > 1 sets the exact number of models built in parallel.
#' @importFrom jsonlite toJSON
#' @examples
#' \dontrun{
#' library(h2o)
#' library(jsonlite)
#' h2o.init()
#' iris_hf <- as.h2o(iris)
#' grid <- h2o.grid("gbm", x = c(1:4), y = 5, training_frame = iris_hf,
#'                  hyper_params = list(ntrees = c(1, 2, 3)))
#' # Get grid summary
#' summary(grid)
#' # Fetch grid models
#' model_ids <- grid@@model_ids
#' models <- lapply(model_ids, function(id) { h2o.getModel(id)})
#' }
#' @export
h2o.grid <- function(algorithm,
                     grid_id,
                     x,
                     y,
                     training_frame,
                     ...,
                     hyper_params = list(),
                     is_supervised = NULL,
                     do_hyper_params_check = FALSE,
                     search_criteria = NULL,
                     export_checkpoints_dir = NULL,
                     recovery_dir = NULL,
                     parallelism = 1)
{
  if (!is.null(is_supervised)) {
    warning("Parameter is_supervised is deprecated. It is not possible to override default behaviour.")
  }
  # Unsupervised algos to account for in grid (these algos do not need a response)
  unsupervised_algos <- c("kmeans", "pca", "svd", "glrm", "extendedisolationforest")
  is_unsupervised <- algorithm %in% c(unsupervised_algos, toupper(unsupervised_algos))
  # Extra model-builder parameters supplied by the caller
  dots <- list(...)
  # Supervised algorithms require a response column, with one exception:
  # a deeplearning autoencoder is trained without a response.
  if (!is_unsupervised) {
    if (!missing(y)) {
      dots$y <- y
    } else if (!(algorithm == "deeplearning" && isTRUE(dots$autoencoder))) {
      # isTRUE() also handles the case where autoencoder was not supplied:
      # dots$autoencoder is then NULL, and `NULL == TRUE` would yield
      # logical(0), making the `if` condition itself fail.
      stop("Must specify response, y")
    }
  }
  if (missing(training_frame)) {
    stop("Must specify training frame, training_frame")
  }
  dots$training_frame <- training_frame
  # If x is missing, assume the user wants all non-response columns as
  # predictors (supervised models only)
  if (!is_unsupervised) {
    if (missing(x)) {
      if (is.numeric(y)) {
        # Response given as a column index: predictors are the remaining
        # column indices. seq_len(ncol(...)) replaces the former col(...),
        # which materializes a full nrow x ncol index matrix and is not
        # defined for H2OFrame objects.
        dots$x <- setdiff(seq_len(ncol(training_frame)), y)
      } else {
        dots$x <- setdiff(colnames(training_frame), y)
      }
    } else {
      dots$x <- x
    }
  }
  if (algorithm %in% c("upliftdrf")) {
    if (is.null(dots$treatment_column)) {
      stop("Must specify treatment column")
    }
  }
  algorithm <- .h2o.unifyAlgoName(algorithm)
  hyper_param_names <- names(hyper_params)
  # NOTE: overlapping definitions of common model parameters and hyper
  # parameters are rejected by the Java backend, so no client-side check
  # is performed here.
  # Get model builder parameters for this model
  all_params <- .h2o.getModelParameters(algo = algorithm)

  # Prepare model parameters
  params <- .h2o.prepareModelParameters(algo = algorithm, params = dots, is_supervised = is_supervised)
  # Validation of input key
  .key.validate(params$key_value)
  # Optionally validate every point of the hyper space against the REST API
  # end-point; this can be expensive for a large hyper space.
  if (do_hyper_params_check) {
    lparams <- params
    # Index grid: one row per combination of hyper-parameter value positions
    # (seq_along is safe for zero-length hyper vectors, unlike 1:length)
    expanded_grid <- expand.grid(lapply(hyper_params, seq_along))
    # Get algo REST version
    algo_rest_version <- .h2o.getAlgoVersion(algo = algorithm)
    # Verify each defined point in hyper space against REST API
    apply(expanded_grid,
          MARGIN = 1,
          FUN = function(permutation) {
      # Materialize the hyper-parameter values for this permutation
      hparams <- lapply(hyper_param_names, function(name) { hyper_params[[name]][[permutation[[name]]]] })
      names(hparams) <- hyper_param_names
      params_for_validation <- lapply(append(lparams, hparams), function(x) { if (is.integer(x)) x <- as.numeric(x); x })
      # We have to repeat part of work used by model builders
      params_for_validation <- .h2o.checkAndUnifyModelParameters(algo = algorithm, allParams = all_params, params = params_for_validation)
      .h2o.validateModelParameters(algorithm, params_for_validation, h2oRestApiVersion = algo_rest_version)
    })
  }

  # Verify and unify the parameters
  params <- .h2o.checkAndUnifyModelParameters(algo = algorithm, allParams = all_params,
                                              params = params, hyper_params = hyper_params)
  # Validate and unify hyper parameters
  hyper_values <- .h2o.checkAndUnifyHyperParameters(algo = algorithm,
                                                    allParams = all_params, hyper_params = hyper_params,
                                                    do_hyper_params_check = do_hyper_params_check)
  # Append grid parameters in JSON form
  params$hyper_parameters <- toJSON(hyper_values, digits = 99)

  # Set directory for checkpoints export
  if (!is.null(export_checkpoints_dir)) {
    params$export_checkpoints_dir <- export_checkpoints_dir
  }
  if (!is.null(recovery_dir)) {
    params$recovery_dir <- recovery_dir
  }
  if (!is.null(parallelism)) {
    params$parallelism <- parallelism
  }

  if (!is.null(search_criteria)) {
    # Append grid search criteria in JSON form.
    # jsonlite unfortunately doesn't handle scalar values so we need to serialize ourselves.
    keys <- paste0("\"", names(search_criteria), "\"", "=")
    vals <- lapply(search_criteria, function(val) { if (is.numeric(val)) val else paste0("\"", val, "\"") })
    body <- paste0(paste0(keys, vals), collapse = ",")
    params$search_criteria <- paste0("{", body, "}")
  }

  # Append grid_id if it is specified
  if (!missing(grid_id)) params$grid_id <- grid_id

  # Trigger the asynchronous grid search job and wait for it to finish
  res <- .h2o.__remoteSend(.h2o.__GRID(algorithm), h2oRestApiVersion = 99, .params = params, method = "POST")
  grid_id <- res$job$dest$name
  job_key <- res$job$key$name
  .h2o.__waitOnJob(job_key)

  h2o.getGrid(grid_id = grid_id)
}

#'
#' Resume previously stopped grid training.
#'
#' @param grid_id ID of existing grid search
#' @param recovery_dir When specified the grid and all necessary data (frames, models) will be saved to this
#'        directory (use HDFS or other distributed file-system). Should the cluster crash during training, the grid
#'        can be reloaded from this directory via \code{h2o.loadGrid} and training can be resumed.
#' @param ...  Additional parameters to modify the resumed Grid.
#' @export
h2o.resumeGrid <- function(grid_id, recovery_dir=NULL, ...) {
    # Recover the algorithm name from the grid's first model; the resume
    # end-point is algorithm-specific.
    existing_grid <- h2o.getGrid(grid_id = grid_id)
    first_model <- h2o.getModel(model_id = existing_grid@model_ids[[1]])
    algorithm <- first_model@algorithm
    # `detach` is a client-side flag, not a backend parameter: strip it
    # from the payload before sending.
    extra_params <- list(...)
    detach <- extra_params$detach
    extra_params$detach <- NULL
    extra_params$grid_id <- grid_id
    extra_params$recovery_dir <- recovery_dir
    res <- .h2o.__remoteSend(.h2o.__GRID_RESUME(algorithm), h2oRestApiVersion = 99, .params = extra_params, method = "POST")
    resumed_grid_id <- res$job$dest$name
    if (is.null(detach) || !detach) {
        # Block until the resumed grid job completes, then return the grid
        .h2o.__waitOnJob(res$job$key$name)
        h2o.getGrid(grid_id = resumed_grid_id)
    } else {
        # Detached mode: return the grid id immediately without waiting
        resumed_grid_id
    }
}

#' Get a grid object from H2O distributed K/V store. 
#' 
#' Note that if neither cross-validation nor a 
#' validation frame is used in the grid search, then the training metrics will display in the 
#' "get grid" output. If a validation frame is passed to the grid, and nfolds = 0, then the 
#' validation metrics will display. However, if nfolds > 1, then cross-validation metrics will 
#' display even if a validation frame is provided.
#'
#' @param grid_id  ID of existing grid object to fetch
#' @param sort_by Sort the models in the grid space by a metric. Choices are "logloss", "residual_deviance", "mse", "auc", "accuracy", "precision", "recall", "f1", etc.
#' @param decreasing Specify whether sort order should be decreasing
#' @param verbose Controls verbosity of the output, if enabled prints out error messages for failed models (default: FALSE)
#' @examples
#' \dontrun{
#' library(h2o)
#' library(jsonlite)
#' h2o.init()
#' iris_hf <- as.h2o(iris)
#' h2o.grid("gbm", grid_id = "gbm_grid_id", x = c(1:4), y = 5,
#'          training_frame = iris_hf, hyper_params = list(ntrees = c(1, 2, 3)))
#' grid <- h2o.getGrid("gbm_grid_id")
#' # Get grid summary
#' summary(grid)
#' # Fetch grid models
#' model_ids <- grid@@model_ids
#' models <- lapply(model_ids, function(id) { h2o.getModel(id)})
#' }
#' @export
h2o.getGrid <- function(grid_id, sort_by, decreasing, verbose = FALSE) {
  # Fetch the grid representation from the backend (v99 Grids end-point)
  json <- .h2o.__remoteSend(method = "GET", h2oRestApiVersion = 99, .h2o.__GRIDS(grid_id, sort_by, decreasing))
  grid_id <- json$grid_id$name
  model_ids <- lapply(json$model_ids, function(model_id) { model_id$name })
  hyper_names <- lapply(json$hyper_names, identity)
  # Normalize failed params: an entry that is NULL or entirely NA means
  # no parameters were recorded for that failure.
  failed_params <- lapply(json$failed_params, function(param) {
                          if (is.null(param) || all(is.na(param))) NULL else param
                        })
  failure_details <- lapply(json$failure_details, identity)
  failure_stack_traces <- lapply(json$failure_stack_traces, identity)
  # An empty list from JSON means no raw params; the slot expects a matrix
  failed_raw_params <- if (is.list(json$failed_raw_params)) matrix(nrow = 0, ncol = 0) else json$failed_raw_params
  warning_details <- lapply(json$warning_details, identity)

  # Surface backend warnings to the R session
  # (iterating elements directly is safe for empty lists, unlike 1:length)
  for (warning_detail in warning_details) {
    warning(warning_detail)
  }
  if (length(failure_details) > 0) {
    warning("Some models were not built due to a failure, for more details run `summary(grid_object, show_stack_traces = TRUE)`")
    if (verbose) {
      # Print per-failure details: the hyper-parameter values of the failed
      # model (when recorded), then the failure message and stack trace.
      for (index in seq_along(failure_details)) {
        if (is.list(failed_params[[index]])) {
          for (hyper_name in hyper_names) {
            cat(sprintf("Hyper-parameter: %s, %s\n", hyper_name, failed_params[[index]][[hyper_name]]))
          }
        }
        cat(sprintf("[%s] failure_details: %s \n", Sys.time(), failure_details[index]))
        cat(sprintf("[%s] failure_stack_traces: %s \n", Sys.time(), failure_stack_traces[index]))
      }
    }
  }

  new("H2OGrid",
      grid_id = grid_id,
      model_ids = model_ids,
      hyper_names = hyper_names,
      failed_params = failed_params,
      failure_details = failure_details,
      failure_stack_traces = failure_stack_traces,
      failed_raw_params = failed_raw_params,
      summary_table     = json$summary_table
      )
}
